# @File    : painter.py
# 文件描述：绘图相关的核心代码，以便后期数据处理时直接复用
import os.path
import typing
import PyQt5.QtWidgets
import matplotlib
import numpy
from PyQt5.QtWidgets import QHBoxLayout, QGraphicsScene
from matplotlib import pyplot as plt
from matplotlib.backends.backend_qt import NavigationToolbar2QT
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg
from matplotlib.figure import Figure
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
import common
# Select Matplotlib's Qt5Agg backend (used by pyplot); the canvases below use the
# Qt Agg backend classes (FigureCanvasQTAgg / NavigationToolbar2QT) directly.
# NOTE(review): this runs after `pyplot` has already been imported above. Recent
# Matplotlib versions allow switching the backend at this point, but calling
# `matplotlib.use` before importing pyplot is the version-independent order — confirm.
matplotlib.use("Qt5Agg")
class Arrow3D(FancyArrowPatch):
    """A ``FancyArrowPatch`` whose two endpoints are given in 3D data coordinates.

    The artist must belong to a 3D axes, because the projection matrix is read
    from ``self.axes.M``. The endpoints are re-projected to 2D on every draw.

    :param xs, ys, zs: 2-element sequences holding the (start, end) coordinate
        along each axis.
    :param args, kwargs: forwarded to ``FancyArrowPatch`` (e.g. ``arrowstyle``).
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # Dummy 2D positions; the real ones are computed at projection time.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def do_3d_projection(self, renderer=None):
        """Project the 3D endpoints to 2D and return the minimum projected depth.

        Newer Matplotlib (3.5+) calls this on every patch of a 3D axes to sort
        them by depth. It also no longer stores the projection matrix on the
        renderer, so the matrix is taken from the owning axes instead.
        """
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return numpy.min(zs)

    def draw(self, renderer):
        # Older Matplotlib versions may draw the artist without calling
        # do_3d_projection first, so always project before drawing.
        self.do_3d_projection(renderer)
        super().draw(renderer)
class MyCanvas(FigureCanvasQTAgg):
    """Qt widget hosting a single Matplotlib ``Figure`` (exposed as ``self.fig``).

    Use :meth:`put_on` to embed the canvas into a ``QGraphicsView`` and its
    navigation toolbar into a separate widget.
    """

    def __init__(self,
                 width=10, heigh=10, dpi=100
                 ):
        """Create the figure and bind it to this canvas.

        :param width: figure width in inches.
        :param heigh: figure height in inches (parameter name kept as-is for
            backward compatibility with keyword callers).
        :param dpi: figure resolution in dots per inch.
        """
        self.fig = Figure(figsize=(width, heigh), dpi=dpi)
        # The base-class constructor attaches self.fig to this canvas;
        # without it nothing would be displayed.
        super(MyCanvas, self).__init__(self.fig)

    @classmethod
    def from_graphicsView(cls, graphicsView, dpi=100):
        """Build a canvas sized from *graphicsView*'s pixel dimensions.

        Pixel sizes are converted to whole inches by truncation, so the figure is
        at most the view's size. A view smaller than *dpi* pixels in either
        direction yields a zero inch dimension.
        """
        return cls(
            width=int(graphicsView.width() / dpi),
            heigh=int(graphicsView.height() / dpi),
            dpi=dpi,
        )

    def put_on(self, graphicsView: PyQt5.QtWidgets.QGraphicsView,
               widget_as_toolbar: PyQt5.QtWidgets.QWidget
               ):
        """Show this canvas inside *graphicsView* and a toolbar inside *widget_as_toolbar*.

        Sets ``self.toolbar`` to the newly created ``NavigationToolbar2QT``.
        """
        # Reuse an existing layout. Constructing QHBoxLayout(widget) on a widget
        # that already has a layout (e.g. on a second call) is refused by Qt,
        # leaving the toolbar un-laid-out.
        toolbar_layout = widget_as_toolbar.layout()
        if toolbar_layout is None:
            toolbar_layout = QHBoxLayout(widget_as_toolbar)
        self.toolbar = NavigationToolbar2QT(self, widget_as_toolbar)
        toolbar_layout.addWidget(self.toolbar)
        # Parent the scene to the view so its lifetime follows the view rather
        # than this local Python reference.
        graphic_scene = QGraphicsScene(graphicsView)
        graphic_scene.addWidget(self)
        graphicsView.setScene(graphic_scene)
class GraphicViewForMpl(FigureCanvasQTAgg):
    """Matplotlib canvas resized to match a given Qt widget when it is created.

    The widget is kept in a private attribute; it is not passed to the base
    class, so it does not become the Qt parent of this canvas.
    """

    def __init__(self, parent: PyQt5.QtWidgets.QWidget):
        super().__init__()
        self.__parent = parent
        target_width, target_height = parent.width(), parent.height()
        self.resize(target_width, target_height)
def plot3D_equal_on_Axes(ax: "plt.Axes", X: numpy.ndarray, Y: numpy.ndarray, Z: numpy.ndarray, *args,
                         margin: float = 1.1, **kwargs):
    """Plot a 3D line on *ax* with the same scale on the x, y and z axes.

    The limits of all three axes are set to a cube centred on the data's
    bounding box. The cube's edge is the largest per-axis extent times *margin*.
    Points with a non-finite coordinate are ignored when computing the limits.
    If all points coincide, a unit extent is used so the limits never collapse.
    Empty input leaves the limits untouched.

    :param ax: 3D axes providing ``set_xlim``/``set_ylim``/``set_zlim``/``plot3D``.
    :param X, Y, Z: coordinate arrays of equal shape (any dimensionality).
    :param args: forwarded to ``ax.plot3D`` (e.g. a format string).
    :param margin: keyword-only scale factor for the cube edge
        (1.1 leaves 5 % padding on each side).
    :param kwargs: forwarded to ``ax.plot3D``.
    """
    # Flatten to (3, n_points) so 1-D and N-D coordinate arrays are handled alike.
    xyz = numpy.asarray([X, Y, Z], dtype=float).reshape(3, -1)
    # Drop points with any NaN/inf coordinate; they would make the limits non-finite.
    xyz = xyz[:, numpy.isfinite(xyz).all(axis=0)]
    if xyz.size:
        min_xyz = xyz.min(axis=1)
        max_xyz = xyz.max(axis=1)
        mid_xyz = (min_xyz + max_xyz) / 2
        span = (max_xyz - min_xyz).max()
        if span <= 0:
            # Single (or coincident) point: identical lower/upper limits are degenerate.
            span = 1.0
        half_len = span / 2 * margin
        xyz_range = numpy.array((mid_xyz - half_len, mid_xyz + half_len))
        ax.set_xlim(xyz_range[:, 0])
        ax.set_ylim(xyz_range[:, 1])
        ax.set_zlim(xyz_range[:, 2])
    ax.plot3D(X, Y, Z, *args, **kwargs)
def plot3D_equal(fig: plt.Figure, X: numpy.ndarray, Y: numpy.ndarray, Z: numpy.ndarray,
                 *args, **kwargs) -> plt.Axes:
    """Plot a 3D line with equal axis scaling on *fig* and return the 3D axes used.

    The figure's current axes is reused when it is already a 3D axes; otherwise
    a new 3D subplot is added. ``fig.gca(projection="3d")`` is not used because
    Matplotlib 3.6 removed keyword arguments from ``Figure.gca``.

    :param fig: target figure.
    :param X, Y, Z: coordinates forwarded to ``plot3D_equal_on_Axes``.
    :param args, kwargs: forwarded to ``plot3D_equal_on_Axes``.
    :returns: the 3D axes that was drawn on.
    """
    # Only query gca() when axes exist; on an empty figure it would create a 2D axes.
    ax = fig.gca() if fig.axes else None
    if getattr(ax, "name", None) != "3d":
        ax = fig.add_subplot(projection="3d")
    plot3D_equal_on_Axes(ax, X, Y, Z, *args, **kwargs)
    return ax
def plot_arc(fig: plt.Figure, P1, P2, rho, ds, axis):
    """Generate an arc with ``common.gen_arc_from_p1_p2_rho`` and draw it on *fig*.

    *fig* is cleared first. The plot contains:
    - the arc as a line with equal axis scaling;
    - its sample points (legend entry "Control Points");
    - arrows labelling the arc's first and last points "P1" / "P2";
    - an arrow along *axis* centred between those points.

    :param fig: figure to draw into (cleared by this call).
    :param P1, P2, rho, ds, axis: passed unchanged to
        ``common.gen_arc_from_p1_p2_rho``; *axis* may be any 3-element array-like.
    :returns: ``(ax, arc_pts, center)`` — the 3D axes plus the generator's outputs.
    """
    arc_pts, center = common.gen_arc_from_p1_p2_rho(P1, P2, rho, ds, axis)
    fig.clf()
    X, Y, Z = arc_pts[:, 0], arc_pts[:, 1], arc_pts[:, 2]
    # Annotate the endpoints of the generated arc (not the requested P1/P2)
    # so the labels sit exactly on the drawn curve.
    start = numpy.array((X[0], Y[0], Z[0]))
    end = numpy.array((X[-1], Y[-1], Z[-1]))
    ax = plot3D_equal(fig, X, Y, Z)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    ax.plot3D(X, Y, Z, ".", label="Control Points")
    # Scale the annotation arrows with the distance between the endpoints.
    length = numpy.linalg.norm(end - start)
    default_v = numpy.array([.20, .20, .30]) * length
    annotate_on_3d(ax, start, default_v, "P1")
    annotate_on_3d(ax, end, default_v, "P2")
    # Convert so plain lists/tuples work too (they do not support `/ 2`).
    axis_vec = numpy.asarray(axis, dtype=float)
    axis_start = (start + end) / 2 - axis_vec / 2
    ax.quiver(*axis_start, *axis_vec)  # arrow along the rotation-axis direction
    ax.text(*axis_start, s="Axis Direction")
    ax.legend()
    return ax, arc_pts, center
def annotate_on_3d(ax: plt.Axes, p: numpy.ndarray, v: numpy.ndarray,
                   text: str):
    """Mark point *p* on a 3D axes and point a labelled arrow at it.

    A marker is drawn at *p*. An arrow with vector *v* is drawn from ``p - v``,
    so it ends at *p*. *text* is placed a little beyond the arrow's tail, at
    ``p - 1.2 * v``.
    """
    x, y, z = p[0], p[1], p[2]
    ax.plot3D(x, y, z, "o")
    tail = p - v
    ax.quiver(*tail, *v)
    label_pos = p - 1.2 * v
    ax.text(*label_pos, s=text)
def plot_traj(ax: plt.Axes, X, Y, Z, clear_axes=True, *args, **kwargs):
    """Draw an X/Y/Z trajectory on a 3D axes with equal scaling and mm axis labels.

    Note: *clear_axes* is positional, so extra positional plot arguments must
    come after it.

    :param ax: target 3D axes.
    :param X, Y, Z: coordinate sequences.
    :param clear_axes: when true, ``ax.cla()`` is called before drawing.
    :param args, kwargs: forwarded to ``plot3D_equal_on_Axes`` (and on to ``ax.plot3D``).
    """
    if clear_axes:
        ax.cla()
    plot3D_equal_on_Axes(ax, X, Y, Z, *args, **kwargs)
    ax.set_title("Control Points")
    for set_axis_label, axis_label in ((ax.set_xlabel, "X/mm"),
                                       (ax.set_ylabel, "Y/mm"),
                                       (ax.set_zlabel, "Z/mm")):
        set_axis_label(axis_label)
    # NOTE(review): a legend needs labelled artists; presumably callers pass
    # label=... through kwargs.
    ax.legend()
def plot_traj_with_Traj(ax: plt.Axes, traj: common.Trajectory, clear_axes=True, *args, **kwargs):
    """Plot the mm position columns of a ``common.Trajectory`` via ``plot_traj``.

    If ``len(traj)`` is 0, returns immediately without touching *ax*
    (not even clearing it).
    """
    if len(traj) == 0:
        return
    column_names = (common.Trajectory.kX_mm, common.Trajectory.kY_mm, common.Trajectory.kZ_mm)
    x_col, y_col, z_col = (traj.df[name] for name in column_names)
    plot_traj(ax, x_col, y_col, z_col, clear_axes, *args, **kwargs)
def plot_traj_with_Recorder(ax: plt.Axes, recorder: common.RecorderBase, clear_axes=True, *args, **kwargs):
    """Plot the x/y/z mm columns recorded by *recorder* as a 3D trajectory.

    If ``recorder.df`` has no rows, returns without touching *ax*, mirroring
    ``plot_traj_with_Traj``.

    :param ax: target 3D axes.
    :param recorder: recorder whose ``df`` holds the ``k_x_mm``/``k_y_mm``/``k_z_mm`` columns.
    :param clear_axes: clear *ax* before plotting.
    :param args, kwargs: forwarded to ``plot_traj`` (and ultimately ``ax.plot3D``).
    """
    df = recorder.df
    if not len(df):
        return
    # Select each coordinate column separately. `df[a, b, c]` would look up a
    # single column keyed by the tuple (a, b, c), and would also hand
    # plot_traj one positional argument instead of X, Y, Z.
    plot_traj(ax,
              df[common.RecorderBase.k_x_mm],
              df[common.RecorderBase.k_y_mm],
              df[common.RecorderBase.k_z_mm],
              clear_axes, *args, **kwargs)
def plot_component_using_Record(ax: plt.Axes, record: common.RecorderBase, components: typing.Iterable[str],
                                column_as_x: str = common.RecorderBase.k_timestamp, clear_axes=True):
    """Plot selected columns of ``record.df`` against one x column on a 2D axes.

    Each component is drawn with the ".-" style and labelled by its column name.
    The title is the base name of ``record.autogen_path_method()``.

    :param ax: target axes.
    :param record: recorder providing ``df`` and ``autogen_path_method``.
    :param components: column names to plot as y series.
    :param column_as_x: column used for the x values (defaults to the timestamp column).
    :param clear_axes: when true, ``ax.cla()`` is called first and a message is printed.
    """
    if clear_axes:
        ax.cla()
        print("Ax cleared")
    for column_name in components:
        ax.plot(record.df[column_as_x], record.df[column_name], ".-", label=column_name)
    title = os.path.basename(record.autogen_path_method())
    ax.set_title(title)
    ax.legend()
